{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import gc\n",
    "import torch\n",
    "from lib.utils import graph_wrapper\n",
    "from transformers import AutoTokenizer, LlamaForCausalLM\n",
    "import time"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_quantized_model(\n",
    "    model_save_path,\n",
    "    base_model,\n",
    "    device,\n",
    "):\n",
    "    \"\"\"Assemble a graph-wrapped Llama that uses the layers of a CALDERA checkpoint.\n",
    "\n",
    "    The base Llama is instantiated on the CPU through ``graph_wrapper``; the\n",
    "    attention projections (q/k/v/o) and the MLP of every decoder layer are then\n",
    "    replaced by the corresponding modules of the checkpoint model. Embeddings,\n",
    "    norms and ``lm_head`` remain those of the base model.\n",
    "\n",
    "    Args:\n",
    "        model_save_path: Path to a checkpoint holding a whole pickled model\n",
    "            (not a state dict).\n",
    "        base_model: Identifier passed to ``from_pretrained`` for the base Llama.\n",
    "        device: Device the returned model is moved to, e.g. ``'cuda:0'``.\n",
    "\n",
    "    Returns:\n",
    "        The graph-wrapped model, moved to ``device``.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If the checkpoint and the base model differ in their number\n",
    "            of decoder layers.\n",
    "    \"\"\"\n",
    "    # SECURITY: the checkpoint is a full pickle, so loading it can execute\n",
    "    # arbitrary code -- only load files from a trusted source. weights_only is\n",
    "    # spelled out because PyTorch 2.6 flipped its default to True, which rejects\n",
    "    # whole-model pickles like this one.\n",
    "    quantized = torch.load(\n",
    "        model_save_path, map_location=device, weights_only=False\n",
    "    ).to(device)  # Llama with Caldera\n",
    "\n",
    "    graph_cls = graph_wrapper.get_graph_wrapper(LlamaForCausalLM, device=\"cpu\")\n",
    "    graph_model = graph_cls.from_pretrained(\n",
    "        base_model, torch_dtype='auto', device_map=\"cpu\", low_cpu_mem_usage=True,\n",
    "        use_flash_attention_2=True\n",
    "    )  # base Llama\n",
    "\n",
    "    base_layers = graph_model.model.layers\n",
    "    quant_layers = quantized.model.layers\n",
    "    if len(base_layers) != len(quant_layers):\n",
    "        raise ValueError(\n",
    "            f'layer count mismatch: base model has {len(base_layers)}, '\n",
    "            f'checkpoint has {len(quant_layers)}'\n",
    "        )\n",
    "\n",
    "    # Swap the quantized projections and MLP into each decoder layer.\n",
    "    for base_layer, quant_layer in zip(base_layers, quant_layers):\n",
    "        for proj in ('q_proj', 'k_proj', 'v_proj', 'o_proj'):\n",
    "            setattr(base_layer.self_attn, proj, getattr(quant_layer.self_attn, proj))\n",
    "        base_layer.mlp = quant_layer.mlp\n",
    "\n",
    "    graph_model.graph_device = device\n",
    "    # The swapped-in modules were already moved to `device` above; this call\n",
    "    # moves what is left of the base model (embeddings, norms, lm_head).\n",
    "    return graph_model.to(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Test Throughput of CALDERA Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pickled CALDERA checkpoint to benchmark; absolute, machine-specific path.\n",
    "MODEL_PATH = \"/media/hdd1/caldera-full-models/llama-2-7b/caldera-rank-256-4B-factors-downdate-no-RHT-ft.pt\"\n",
    "# Model identifier handed to from_pretrained for the base model and the tokenizer.\n",
    "BASE_MODEL = \"meta-llama/Llama-2-7b-hf\"\n",
    "# Device the models and the benchmark input are placed on.\n",
    "DEVICE = \"cuda:2\"\n",
    "# Number of timed forward passes per benchmark run.\n",
    "SAMPLES = 500"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the CALDERA checkpoint into a graph-wrapped Llama on DEVICE.\n",
    "model = load_quantized_model(MODEL_PATH, BASE_MODEL, DEVICE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def eval_throughput(model, samples, base_model, device, batch_size=1, seq_len=1):\n",
    "    \"\"\"Time repeated forward passes of ``model`` and print latency and throughput.\n",
    "\n",
    "    A single token id taken from a fixed prompt is tiled to shape\n",
    "    ``(batch_size, seq_len)`` and fed to the model ``samples`` times after one\n",
    "    untimed warm-up call.\n",
    "\n",
    "    Args:\n",
    "        model: Callable model that accepts a tensor of token ids.\n",
    "        samples: Number of timed forward passes; must be positive.\n",
    "        base_model: Identifier passed to ``AutoTokenizer.from_pretrained``.\n",
    "        device: Device the input is placed on and, if it is a CUDA device,\n",
    "            synchronized around the timed region.\n",
    "        batch_size: Number of rows of the input tensor.\n",
    "        seq_len: Number of columns of the input tensor.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If ``samples`` is not positive.\n",
    "\n",
    "    Note:\n",
    "        The printed figures are per forward pass, so the ``s/tok`` and ``tok/s``\n",
    "        labels are only literal when ``batch_size * seq_len == 1``.\n",
    "    \"\"\"\n",
    "    if samples <= 0:\n",
    "        raise ValueError(f'samples must be positive, got {samples}')\n",
    "\n",
    "    target = torch.device(device)\n",
    "\n",
    "    def _sync():\n",
    "        # torch.cuda.synchronize() without an argument waits on the *current*\n",
    "        # device, which need not be the one the model runs on, so name it.\n",
    "        if target.type == 'cuda':\n",
    "            torch.cuda.synchronize(target)\n",
    "\n",
    "    tokenizer = AutoTokenizer.from_pretrained(base_model)\n",
    "\n",
    "    prompt = 'It is a truth universally acknowledged that'\n",
    "    inputs = tokenizer(prompt, return_tensors='pt')\n",
    "    token = inputs['input_ids'][0:1, 0:1].to(device).repeat(batch_size, seq_len)\n",
    "\n",
    "    # Inference only: skip autograd bookkeeping so it cannot distort timing or memory use.\n",
    "    with torch.no_grad():\n",
    "        model(token)  # untimed warm-up call\n",
    "\n",
    "        _sync()\n",
    "        start = time.perf_counter()\n",
    "        for _ in range(samples):\n",
    "            model(token)\n",
    "        _sync()\n",
    "        end = time.perf_counter()\n",
    "\n",
    "    elapsed = end - start\n",
    "    print('TIME:', elapsed / samples, 's/tok')\n",
    "    print (f'THROUGHPUT: {samples / elapsed} tok/s')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Benchmark the CALDERA model with the default batch_size=1, seq_len=1.\n",
    "eval_throughput(model, SAMPLES, BASE_MODEL, DEVICE)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Compare with the Unquantized Base Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Drop the notebook's reference to the current model, run the garbage collector,\n",
    "# and release PyTorch's cached CUDA memory before the next model is loaded.\n",
    "del model\n",
    "gc.collect()\n",
    "torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Unquantized baseline: BASE_MODEL loaded directly onto DEVICE through the graph wrapper.\n",
    "model = graph_wrapper.get_graph_wrapper(LlamaForCausalLM, device=DEVICE).from_pretrained(\n",
    "    BASE_MODEL,\n",
    "    torch_dtype='auto',\n",
    "    device_map=DEVICE,\n",
    "    low_cpu_mem_usage=True,\n",
    "    use_flash_attention_2=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Benchmark the unquantized base model with the default batch_size=1, seq_len=1.\n",
    "eval_throughput(model, SAMPLES, BASE_MODEL, DEVICE)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
